""" Graph Average Exploration - Figure 4"""

import os
import csv
import pandas as pd
import matplotlib
matplotlib.use('agg')
import matplotlib.style
# The 'seaborn-white' style was renamed 'seaborn-v0_8-white' in matplotlib 3.6 and
# the old name was removed in 3.8 (style.use then rejects it). Prefer the new name
# and fall back to the old one on matplotlib releases that predate the rename.
_PLOT_STYLE = 'seaborn-v0_8-white'
if _PLOT_STYLE not in matplotlib.style.available:
    _PLOT_STYLE = 'seaborn-white'
matplotlib.style.use(_PLOT_STYLE)
matplotlib.rc('text', usetex=False)        # render text without an external LaTeX install
matplotlib.rc('font', family='CMU Serif')
matplotlib.rc('font', serif='CMU Serif')
matplotlib.rc('font', size=13)
import matplotlib.pyplot as plt

# Input (simulation CSVs) and output (PDF figures) directories, relative to the CWD.
DIR_DATA = 'sims/'
DIR_RESULTS = 'results/'

# Axis configuration for the exploration plots.
YLIM_EXPLORATION = 15                                # upper y limit (lower limit is 0)
YTICKS_EXPLORATION = list(range(0, 16, 3))           # 0, 3, ..., 15
XTICKS_EXPLORATION = list(range(0, 501, 100))        # 0, 100, ..., 500

# Typography and line styling.
PLOT_FONT_SIZE = 15
PLOT_AXIS_FONT_SIZE = 15
PLOT_LINE_WIDTH = 1.25
PLOT_AXIS_LABEL_PAD = 10

def load_csv(fname, quote=None, must_exist=True):
    """Read a CSV file and return its records as lists of strings.

    Args:
        fname: Path to the file (str or path-like); surrounding whitespace
            is stripped.
        quote: Optional quote character passed to ``csv.reader`` as
            ``quotechar``. When falsy, the 'excel' dialect default is used.
        must_exist: If True, a missing file raises; otherwise an empty list
            is returned.

    Returns:
        A list containing one list of string fields per CSV record. Blank
        lines come back as empty lists, exactly as ``csv.reader`` yields them.

    Raises:
        FileNotFoundError: (a subclass of IOError/OSError) if the file does
            not exist and ``must_exist`` is True.
    """
    fname = str(fname).strip()
    if not os.path.exists(fname):
        if must_exist:
            raise FileNotFoundError('CSV file not found: %r' % fname)
        return []
    reader_kwargs = {'quotechar': quote} if quote else {}
    # newline='' is what the csv docs require: the reader itself recognises
    # '\r', '\n' and '\r\n' line endings, replacing the old 'U' open mode,
    # which was removed in Python 3.11.
    with open(fname, 'r', newline='') as f:
        return list(csv.reader(f, dialect='excel', **reader_kwargs))

def plot_comparison(label_a, data_a, label_b, data_b, fname):
    """Plot the average exploration of two histories on one axes and save it as a PDF.

    Args:
        label_a: Legend label for ``data_a`` (drawn as a dashed line).
        data_a: Sequence (or single-column DataFrame) of y values for the
            first history.
        label_b: Legend label for ``data_b`` (drawn as a solid line).
        data_b: Sequence (or single-column DataFrame) of y values for the
            second history. It may have a different length than ``data_a``.
        fname: Output PDF path. A missing parent directory is created and an
            existing file is overwritten.
    """
    plt.close('all')
    fig, ax = plt.subplots()
    fig.set_size_inches(6.5, 5)

    # Limits/ticks come only from the module-level settings; a falsy setting
    # leaves matplotlib's default. They are applied before plotting, as before.
    if YLIM_EXPLORATION:
        ax.set_ylim(0, YLIM_EXPLORATION)
    if YTICKS_EXPLORATION:
        ax.set_yticks(YTICKS_EXPLORATION)
    if XTICKS_EXPLORATION:
        ax.set_xticks(XTICKS_EXPLORATION)

    # Each series is drawn against its own 0-based position (0, 1, 2, ...),
    # not against the data's index labels, so unequal lengths do not raise.
    ax.plot(list(range(len(data_a))), data_a, lw=PLOT_LINE_WIDTH, ls='--',
            label=label_a, color='black', alpha=.85)
    ax.plot(list(range(len(data_b))), data_b, lw=PLOT_LINE_WIDTH, ls='-',
            label=label_b, color='black', alpha=.85)

    ax.xaxis.grid(True, linestyle='-', color='grey', alpha=.25)
    ax.yaxis.grid(True, linestyle='-', color='grey', alpha=.25)
    ax.set_xlabel('Move Number', labelpad=PLOT_AXIS_LABEL_PAD, fontsize=PLOT_AXIS_FONT_SIZE)
    ax.set_ylabel('Number of Exploratory Actions', labelpad=PLOT_AXIS_LABEL_PAD, fontsize=PLOT_AXIS_FONT_SIZE)

    # Upper-left legend with a white, borderless frame.
    legend = ax.legend(frameon=True, loc='upper left')
    frame = legend.get_frame()
    frame.set_facecolor('white')
    frame.set_edgecolor('white')

    out_dir = os.path.dirname(fname)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(fname, format='pdf')   # overwrites any existing file
    plt.close(fig)


# COMMANDLINE  ===============================================================================================================================================
if __name__ == "__main__":

    def _average_exploration(config):
        """Return the mean 'explore' value per 'move' for sims/sim-config-<config>.csv.

        Each non-blank CSV row must have exactly two fields: an integer move
        number and a float exploration count.
        """
        rows = load_csv('%ssim-config-%d.csv' % (DIR_DATA, config))
        # csv.reader yields [] for blank lines (e.g. a trailing newline);
        # filter(None, ...) drops them before the two-field unpacking.
        records = [(int(move), float(explore)) for move, explore in filter(None, rows)]
        df = pd.DataFrame(records, columns=('move', 'explore'))
        return df.groupby(['move']).mean()

    # Simulation configurations used by the two figures, loaded in this order.
    averages = {config: _average_exploration(config) for config in (1, 3, 10, 12)}

    plot_comparison(label_a='Losses', data_a=averages[3], label_b='Gains', data_b=averages[1],
                    fname=DIR_RESULTS + '10_Graph_Figure7a.pdf')
    plot_comparison(label_a='Losses', data_a=averages[12], label_b='Gains', data_b=averages[10],
                    fname=DIR_RESULTS + '10_Graph_Figure7b.pdf')

